#!/usr/bin/env python3

"""Tests for mbedtls-rewrite-branch-style.

Needs git >=2.19 (for `git range-diff`).
"""

# Copyright The Mbed TLS Contributors
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import contextlib
import os
import re
import subprocess
from typing import Iterator, List, Optional
import unittest



class temp_git_worktree:
    """Context manager that provides a throwaway Git worktree.

    Entering the context adds a worktree at ``temp_path`` to the repository
    at ``main_path``. Leaving the context forcibly removes that worktree.
    """

    def __init__(
            self,
            main_path: str, temp_path: str,
            branch: Optional[str] = None,
            head: Optional[str] = None
    ) -> None:
        """Record where and how the worktree will be created.

        * ``main_path``: the Git repository to add the worktree to.
        * ``temp_path``: the location of the new worktree.
        * ``branch``: name of a new branch to create for the worktree.
          If omitted, the worktree gets a detached head.
        * ``head``: the revision to check out. If omitted, no revision is
          passed to Git.

        Both paths are made absolute here. Nothing is created until the
        context is entered.
        """
        self.branch = branch
        self.head = head
        self.main_path = os.path.abspath(main_path)
        self.temp_path = os.path.abspath(temp_path)

    @property
    def path(self) -> str:
        """The absolute path to the worktree."""
        return self.temp_path

    def _git(self, args: List[str]) -> None:
        """Run a Git command in the main repository, raising on failure."""
        subprocess.check_call(['git', '-C', self.main_path] + args)

    def __enter__(self) -> 'temp_git_worktree':
        """Create the worktree and return this object."""
        if self.branch is None:
            checkout_mode = ['--detach']
        else:
            checkout_mode = ['-b', self.branch]
        start_point = [] if self.head is None else [self.head]
        self._git(['worktree', 'add', '--quiet'] +
                  checkout_mode +
                  [self.temp_path] +
                  start_point)
        return self

    def __exit__(self, *_exc) -> None:
        """Remove the worktree, even if it contains modifications."""
        self._git(['worktree', 'remove', '--force', self.temp_path])


class Utilities(unittest.TestCase):
    """Common fixtures and Git helpers for the test cases in this file.

    The test classes share a single Git clone (``git_main_dir``), which
    ``setUpClass`` creates if it doesn't exist and then updates. Branches
    named via ``temp_branch`` are deleted before and after each test.
    """

    # Name of the clone's directory inside the directory containing this
    # file, unless overridden by the TMP_GIT_DIR environment variable.
    GIT_DIR = 'clone-for-test'
    # Repository that the clone is created from.
    ORIGIN_URL = 'https://github.com/Mbed-TLS/mbedtls'
    # Branch patterns fetched from origin as remote-tracking branches.
    FETCH_BRANCHES = [
        'features/new-code-style/development',
        'features/new-code-style/mbedtls-2.28',
        'features/new-code-style/test/*',
    ]
    # Tags fetched from origin.
    FETCH_TAGS = [
        'mbedtls-2.28.0', # for target branch detection
    ]
    # File name of the script under test, used when the SCRIPT_PATH
    # environment variable is not set.
    SCRIPT_NAME = 'mbedtls-rewrite-branch-style'

    # Absolute path to the script under test. Set by setUpClass.
    script = None #type: Optional[str]
    # Absolute path to the main worktree of the clone. Set by setUpClass.
    git_main_dir = None #type: Optional[str]


    @classmethod
    def call_git(cls, args: List[str], **kwargs) -> None:
        """Run a Git command in the main worktree.

        Extra keyword arguments are passed to subprocess.check_call.
        Raise subprocess.CalledProcessError if Git fails.
        """
        assert cls.git_main_dir is not None
        subprocess.check_call(['git', '-C', cls.git_main_dir] + args,
                              **kwargs)

    @classmethod
    def check_git(cls, args: List[str], **kwargs) -> bool:
        """Run a Git command that reports a boolean through its exit status.

        Return True if Git exits with status 0 and False if it exits with
        status 1. Any other exit status is treated as an error and
        propagates as subprocess.CalledProcessError.
        """
        try:
            cls.call_git(args, **kwargs)
            return True
        except subprocess.CalledProcessError as cpe:
            if cpe.returncode == 1:
                return False
            raise

    @classmethod
    def git_output(cls, args: List[str], **kwargs) -> str:
        """Run a Git command in the main worktree and return its output.

        Extra keyword arguments are passed to subprocess.check_output.
        Raise subprocess.CalledProcessError if Git fails.
        """
        assert cls.git_main_dir is not None
        return subprocess.check_output(['git', '-C', cls.git_main_dir] + args,
                                       universal_newlines=True,
                                       **kwargs)

    @classmethod
    def git_line(cls, args: List[str], **kwargs) -> str:
        """The output of a Git command, without trailing newlines."""
        return cls.git_output(args, **kwargs).rstrip('\n')

    @classmethod
    def git_lines(cls, args: List[str], **kwargs) -> List[str]:
        """The output of a Git command, as a list of lines."""
        return cls.git_output(args, **kwargs).splitlines()

    @classmethod
    def git_sha(cls, ref: str) -> Optional[str]:
        """The sha for a Git revision, or None if there is no such revision."""
        # Use --verify so that Git only accepts something that resolves to
        # a single object. Without it, `git rev-parse` can echo an argument
        # that isn't a revision (e.g. the name of an existing file) and
        # exit successfully. --quiet avoids an error message on stderr
        # when the revision doesn't exist, which is an expected outcome.
        try:
            return cls.git_line(['rev-parse', '--verify', '--quiet', ref])
        except subprocess.CalledProcessError:
            return None

    @classmethod
    def git_head(cls, path: str) -> str:
        """The sha of the head in a Git worktree.

        ``path`` is the worktree to query, which can be different from the
        main worktree. Raise subprocess.CalledProcessError if Git fails.
        """
        output = subprocess.check_output(['git', '-C', path,
                                          'rev-parse', 'HEAD'],
                                         universal_newlines=True)
        return output.rstrip('\n')

    @classmethod
    def git_delete_branch(cls, branch: str) -> None:
        """Delete branch if it exists. Do nothing if it doesn't exist."""
        ref = 'refs/heads/' + branch
        if cls.check_git(['show-ref', '--quiet', '--verify', ref]):
            cls.call_git(['branch', '--quiet', '--delete', '--force', branch])

    @classmethod
    def create_git_clone(cls) -> None:
        """Create the git clone used by the tests."""
        assert cls.git_main_dir is not None
        subprocess.check_call(['git', 'clone',
                               '--no-tags', '--single-branch',
                               '--branch=master',
                               cls.ORIGIN_URL,
                               cls.git_main_dir])

    @classmethod
    def update_git_clone(cls) -> None:
        """Set up the existing git clone with what the tests needs.

        This method is destructive. It may create, update or remove branches,
        tags, worktrees, worktree contents, etc.
        """
        cls.call_git(['fetch', '--quiet', '--force', 'origin'] +
                     ['refs/tags/{name}:refs/tags/{name}'.format(name=name)
                      for name in cls.FETCH_TAGS] +
                     ['refs/heads/{name}:refs/remotes/origin/{name}'.format(name=name)
                      for name in cls.FETCH_BRANCHES])
        # Make sure the main worktree doesn't have a branch checked out that
        # might interfere with some of the tests.
        cls.call_git(['checkout', '--quiet', '--detach'])

    @contextlib.contextmanager
    def temp_worktree(self, **kwargs) -> Iterator[temp_git_worktree]:
        """Run code in a temporary worktree of the clone.

        The worktree is created next to the main worktree, in a directory
        named after the current test case. Keyword arguments are passed
        to the temp_git_worktree constructor.
        """
        assert self.git_main_dir is not None
        path = os.path.join(os.path.dirname(self.git_main_dir),
                            'worktree-for-test-' + self.id())
        with temp_git_worktree(self.git_main_dir, path, **kwargs) as worktree:
            yield worktree

    def assert_same_history(self,
                            common_ancestor: str,
                            branch1: str, branch2: str) -> None:
        """Assert that branch1 and branch2 consist of common_ancestor followed by similar commits.

        Similar commits are defined as equal according to `git range-diff`:
        they must have the same content and description, but may have different
        metadata (in particular, committer name and date).
        """
        for branch in (branch1, branch2):
            if not self.check_git(['merge-base', '--is-ancestor',
                                   common_ancestor, branch]):
                self.fail('{} ({}) is not an ancestor of {} ({})'
                          .format(common_ancestor,
                                  self.git_sha(common_ancestor),
                                  branch, self.git_sha(branch)))
        range_diff = self.git_lines(['range-diff', '--no-color',
                                     common_ancestor, branch1, branch2])
        # Every line of output must be the summary line of a pair of
        # commits that range-diff considers equal ("=").
        for line in range_diff:
            self.assertRegex(line, r'\A *[0-9]+: *[0-9a-f]+ *= ')

    @classmethod
    def setUpClass(cls) -> None:
        """Locate the script under test and prepare the Git clone.

        The script path and the clone location can be overridden with the
        environment variables SCRIPT_PATH and TMP_GIT_DIR respectively.
        The clone is created if its directory doesn't exist, and updated
        in any case.
        """
        my_dir = os.path.dirname(__file__)
        cls.script = os.path.abspath(
            os.getenv('SCRIPT_PATH',
                      os.path.join(my_dir, os.pardir, 'bin', cls.SCRIPT_NAME)))
        cls.git_main_dir = os.path.abspath(
            os.getenv('TMP_GIT_DIR',
                      os.path.join(my_dir, cls.GIT_DIR)))
        assert cls.git_main_dir is not None
        if not os.path.exists(cls.git_main_dir):
            cls.create_git_clone()
        cls.update_git_clone()

    def temp_branch(self, name: str) -> str:
        """A name for a temporary branch.

        temp_branch(name) returns the same branch name when passed the same
        name parameter and called from the same test case, and distinct
        branch names otherwise.

        The branch is guaranteed not to exist when the test starts,
        and will be removed when the test ends.
        """
        return 'tmp/{}/{}'.format(self.id(), name)

    def delete_temporary_branches(self, case_id: str) -> None:
        """Delete all the temporary branches of the given test case.

        ``case_id`` is the test case identifier, as used by temp_branch
        to construct branch names.
        """
        branches = self.git_lines(['for-each-ref', '--format=%(refname:short)',
                                   'refs/heads/tmp/{}/**'.format(case_id)])
        # `git branch --delete` needs at least one branch name.
        if branches:
            self.call_git(['branch', '--quiet', '--delete', '--force'] +
                          branches)

    def setUp(self) -> None:
        """Remove temporary branches left over from a previous run."""
        self.delete_temporary_branches(self.id())

    def tearDown(self) -> None:
        """Remove the temporary branches created by this test."""
        self.delete_temporary_branches(self.id())

    def run_script(
            self,
            args: List[str],
            cwd: Optional[str] = None
    ) -> subprocess.CompletedProcess:
        """Run the script with the given arguments.

        Run in ``cwd``, defaulting to the main worktree. Return the
        completed process with stdout and stderr captured as text.
        A nonzero exit status does not raise an exception.
        """
        assert self.script is not None
        if cwd is None:
            cwd = self.git_main_dir
        return subprocess.run([self.script] + args,
                              check=False,
                              cwd=cwd,
                              stdout=subprocess.PIPE, stderr=subprocess.PIPE,
                              universal_newlines=True)

class Smoke(Utilities):
    """Basic checks of the script's command line handling."""

    def test_help(self) -> None:
        """--help succeeds and writes the help text to stdout only."""
        result = self.run_script(['--help'])
        self.assertEqual(result.returncode, 0)
        for expected in ('Rewrite', '--verbose'):
            self.assertIn(expected, result.stdout)
        self.assertEqual(result.stderr, '')

    def test_unknown_option(self) -> None:
        """An unknown option fails with status 2 and a message on stderr only."""
        result = self.run_script(['--barf'])
        self.assertEqual(result.returncode, 2)
        self.assertEqual(result.stdout, '')
        self.assertNotEqual(result.stderr, '')

class Result(Utilities):
    """Check the outcome of rewriting known branches against reference results.

    Each test case relies on three branches in the main mbedtls repository,
    under the features/new-code-style/test namespace: the original branch,
    the target branch and the expected rewritten branch.

    To add a new test case:
    * Pick a branch to test and a name $foo.
    * Run the rewrite script and carefully review the output.
    * Push the original branch, the target branch and the rewritten branch
      to the main mbedtls repository in the features/new-code-style/test
      namespace.

    For example:
    ```
    mbedtls-rewrite-branch-style --onto=origin/features/new-code-style/development --backup-branch=1234-old --new-branch=1234-new
    : review the output
    git push origin 1234-old:refs/heads/features/new-code-style/test/1234/old
    git push origin origin/features/new-code-style/development:refs/heads/features/new-code-style/test/1234/target
    git push origin 1234-new:refs/heads/features/new-code-style/test/1234/new
    ```
    """

    # Templates for the remote-tracking branches of a test case.
    # The placeholder is the name of the test case.
    OLD_BRANCH = 'origin/features/new-code-style/test/{}/old'
    TARGET_BRANCH = 'origin/features/new-code-style/test/{}/target'
    GOOD_NEW_BRANCH = 'origin/features/new-code-style/test/{}/new'

    def run_on_branch(self, name) -> None:
        """Rewrite the old branch of test case ``name`` and check the result.

        The script runs in a temporary worktree with the old branch checked
        out in detached mode. It must succeed and must create a new branch
        whose history is similar to the reference branch, starting from
        the style switch commit of the target branch.
        """
        target_branch = self.TARGET_BRANCH.format(name)
        reference_branch = self.GOOD_NEW_BRANCH.format(name)
        rewritten_branch = self.temp_branch('new')
        script_args = ['-v', '--backup-branch=',
                       '--onto', target_branch,
                       '--new-branch', rewritten_branch]
        with self.temp_worktree(head=self.OLD_BRANCH.format(name)) as worktree:
            result = self.run_script(script_args, cwd=worktree.path)
            self.assertEqual(result.returncode, 0)
            self.assert_same_history(
                target_branch + '^{/Switch to the new code style}',
                reference_branch,
                rewritten_branch)

    def test_6863(self) -> None:
        self.run_on_branch('6863')

    def test_6882(self) -> None:
        self.run_on_branch('6882')

    def test_6883(self) -> None:
        self.run_on_branch('6883')

    def test_6888(self) -> None:
        self.run_on_branch('6888')

    def test_6889(self) -> None:
        self.run_on_branch('6889')

    def test_6890(self) -> None:
        self.run_on_branch('6890')

    def test_rename_delete(self) -> None:
        self.run_on_branch('rename-delete')

class Detection(Utilities):
    """Test target branch detection."""

    # Template for the remote-tracking branch to rewrite.
    # The placeholder is the name of the test case.
    OLD_BRANCH = 'origin/features/new-code-style/test/{}/old'

    def run_on_branch(self, name: str, expected_target: str) -> None:
        """Run the script without --onto and check which target it guesses.

        * ``name``: the name of the test case, selecting the old branch.
        * ``expected_target``: the literal name of the branch that the
          script is expected to report as the guessed target, optionally
          preceded by a prefix ending with a slash (e.g. a remote name).

        The script must succeed and must report its guess on stderr.
        """
        old_branch = self.OLD_BRANCH.format(name)
        new_branch = self.temp_branch('new')
        with self.temp_worktree(head=old_branch) as worktree:
            cp = self.run_script(['-v', '--backup-branch=',
                                  '--new-branch', new_branch],
                                 cwd=worktree.path)
            self.assertEqual(cp.returncode, 0)
            # Escape the branch name: it is a literal string, and may
            # contain regex metacharacters such as "." in "mbedtls-2.28".
            self.assertRegex(cp.stderr,
                             r'Guessed target branch: (\S+/)?' +
                             re.escape(expected_target))

    def test_6802(self) -> None:
        self.run_on_branch('6802', 'development')

    def test_6844(self) -> None:
        self.run_on_branch('6844', 'mbedtls-2.28')


class Options(Utilities):
    """Test the script's behavior with particular sets of options.

    The tests in this class use the branches of the 6863 test case.
    """

    old_branch = 'origin/features/new-code-style/test/6863/old'
    target_branch = 'origin/features/new-code-style/test/6863/target'
    target_commit = target_branch + '^{/Switch to the new code style}'
    good_new_branch = 'origin/features/new-code-style/test/6863/new'

    def test_update_in_place(self) -> None:
        """Without --new-branch, expect the checked-out branch to be rewritten."""
        work_branch = self.temp_branch('work')
        script_args = ['-v', '--backup-branch=',
                       '--onto', self.target_branch]
        with self.temp_worktree(branch=work_branch,
                                head=self.old_branch) as worktree:
            result = self.run_script(script_args, cwd=worktree.path)
            self.assertEqual(result.returncode, 0)
            self.assert_same_history(self.target_commit,
                                     self.good_new_branch,
                                     work_branch)


# When this file is executed as a script, run the test cases it defines.
if __name__ == '__main__':
    unittest.main()
